{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Image segmentation with a U-Net-like architecture\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2019/03/20<br>\n",
    "**Last modified:** 2020/04/20<br>\n",
    "**Description:** Image segmentation model trained from scratch on the Oxford Pets dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Download the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!curl -O http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz\n",
    "!curl -O http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz\n",
    "!tar -xf images.tar.gz\n",
    "!tar -xf annotations.tar.gz\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare paths of input images and target segmentation masks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "input_dir = \"images/\"\n",
    "target_dir = \"annotations/trimaps/\"\n",
    "img_size = (160, 160)\n",
    "num_classes = 4\n",
    "batch_size = 32\n",
    "\n",
    "input_img_paths = sorted(\n",
    "    [\n",
    "        os.path.join(input_dir, fname)\n",
    "        for fname in os.listdir(input_dir)\n",
    "        if fname.endswith(\".jpg\")\n",
    "    ]\n",
    ")\n",
    "target_img_paths = sorted(\n",
    "    [\n",
    "        os.path.join(target_dir, fname)\n",
    "        for fname in os.listdir(target_dir)\n",
    "        if fname.endswith(\".png\") and not fname.startswith(\".\")\n",
    "    ]\n",
    ")\n",
    "\n",
    "print(\"Number of samples:\", len(input_img_paths))\n",
    "\n",
    "for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):\n",
    "    print(input_path, \"|\", target_path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## What does one input image and corresponding segmentation mask look like?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "from tensorflow.keras.preprocessing.image import load_img\n",
    "import PIL\n",
    "from PIL import ImageOps\n",
    "\n",
    "# Display input image #9\n",
    "display(Image(filename=input_img_paths[9]))\n",
    "\n",
    "# Display auto-contrast version of corresponding target (per-pixel categories)\n",
    "img = PIL.ImageOps.autocontrast(load_img(target_img_paths[9]))\n",
    "display(img)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare `Sequence` class to load & vectorize batches of data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "from tensorflow.keras.preprocessing.image import load_img\n",
    "\n",
    "\n",
    "class OxfordPets(keras.utils.Sequence):\n",
    "    \"\"\"Helper to iterate over the data (as Numpy arrays).\n",
    "\n",
    "    Each batch is an `(x, y)` pair: `x` is a float32 array of 3-channel input\n",
    "    images and `y` is a uint8 array of single-channel target masks, both\n",
    "    loaded at `img_size`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):\n",
    "        self.batch_size = batch_size\n",
    "        self.img_size = img_size\n",
    "        self.input_img_paths = input_img_paths\n",
    "        self.target_img_paths = target_img_paths\n",
    "\n",
    "    def __len__(self):\n",
    "        # Floor division: trailing samples that do not fill a batch are dropped.\n",
    "        return len(self.target_img_paths) // self.batch_size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Returns tuple (input, target) corresponding to batch #idx.\"\"\"\n",
    "        i = idx * self.batch_size\n",
    "        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]\n",
    "        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]\n",
    "        # Size the arrays from this instance's batch size, not the notebook-level\n",
    "        # `batch_size` global, so a Sequence built with a different value still\n",
    "        # produces correctly shaped batches.\n",
    "        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype=\"float32\")\n",
    "        for j, path in enumerate(batch_input_img_paths):\n",
    "            img = load_img(path, target_size=self.img_size)\n",
    "            x[j] = img\n",
    "        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype=\"uint8\")\n",
    "        for j, path in enumerate(batch_target_img_paths):\n",
    "            img = load_img(path, target_size=self.img_size, color_mode=\"grayscale\")\n",
    "            # Add the trailing channel axis expected by the target array.\n",
    "            y[j] = np.expand_dims(img, 2)\n",
    "        return x, y\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare U-Net Xception-style model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers\n",
    "\n",
    "\n",
    "def get_model(img_size, num_classes):\n",
    "    \"\"\"Build the U-Net-like, Xception-style segmentation model.\n",
    "\n",
    "    Args:\n",
    "        img_size: Tuple with the two spatial dimensions of the input images.\n",
    "        num_classes: Number of channels of the per-pixel softmax output.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled `keras.Model` mapping an image to per-pixel class scores.\n",
    "    \"\"\"\n",
    "    inputs = keras.Input(shape=img_size + (3,))\n",
    "\n",
    "    # --- Encoder: progressively downsample the input ---\n",
    "\n",
    "    # Stem: strided conv -> batch norm -> relu\n",
    "    net = layers.Conv2D(32, 3, strides=2, padding=\"same\")(inputs)\n",
    "    net = layers.BatchNormalization()(net)\n",
    "    net = layers.Activation(\"relu\")(net)\n",
    "\n",
    "    skip = net  # Kept for the first residual connection\n",
    "\n",
    "    # Three blocks that differ only in their number of filters.\n",
    "    for depth in (64, 128, 256):\n",
    "        net = layers.Activation(\"relu\")(net)\n",
    "        net = layers.SeparableConv2D(depth, 3, padding=\"same\")(net)\n",
    "        net = layers.BatchNormalization()(net)\n",
    "\n",
    "        net = layers.Activation(\"relu\")(net)\n",
    "        net = layers.SeparableConv2D(depth, 3, padding=\"same\")(net)\n",
    "        net = layers.BatchNormalization()(net)\n",
    "\n",
    "        net = layers.MaxPooling2D(3, strides=2, padding=\"same\")(net)\n",
    "\n",
    "        # A strided 1x1 conv brings the skip tensor to the block's output shape\n",
    "        # so the two can be summed.\n",
    "        shortcut = layers.Conv2D(depth, 1, strides=2, padding=\"same\")(skip)\n",
    "        net = layers.add([net, shortcut])\n",
    "        skip = net  # Becomes the residual of the next block\n",
    "\n",
    "    # --- Decoder: progressively upsample back towards the input size ---\n",
    "\n",
    "    skip = net  # Kept for the first decoder residual connection\n",
    "\n",
    "    for depth in (256, 128, 64, 32):\n",
    "        net = layers.Activation(\"relu\")(net)\n",
    "        net = layers.Conv2DTranspose(depth, 3, padding=\"same\")(net)\n",
    "        net = layers.BatchNormalization()(net)\n",
    "\n",
    "        net = layers.Activation(\"relu\")(net)\n",
    "        net = layers.Conv2DTranspose(depth, 3, padding=\"same\")(net)\n",
    "        net = layers.BatchNormalization()(net)\n",
    "\n",
    "        net = layers.UpSampling2D(2)(net)\n",
    "\n",
    "        # Upsample the skip tensor, then match its depth with a 1x1 conv.\n",
    "        shortcut = layers.UpSampling2D(2)(skip)\n",
    "        shortcut = layers.Conv2D(depth, 1, padding=\"same\")(shortcut)\n",
    "        net = layers.add([net, shortcut])\n",
    "        skip = net  # Becomes the residual of the next block\n",
    "\n",
    "    # Per-pixel classification head\n",
    "    outputs = layers.Conv2D(num_classes, 3, activation=\"softmax\", padding=\"same\")(net)\n",
    "\n",
    "    return keras.Model(inputs, outputs)\n",
    "\n",
    "\n",
    "# Free up RAM in case the model definition cells were run multiple times\n",
    "keras.backend.clear_session()\n",
    "\n",
    "# Build model\n",
    "model = get_model(img_size, num_classes)\n",
    "model.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Set aside a validation split\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Split our img paths into a training and a validation set\n",
    "val_samples = 1000\n",
    "random.Random(1337).shuffle(input_img_paths)\n",
    "random.Random(1337).shuffle(target_img_paths)\n",
    "train_input_img_paths = input_img_paths[:-val_samples]\n",
    "train_target_img_paths = target_img_paths[:-val_samples]\n",
    "val_input_img_paths = input_img_paths[-val_samples:]\n",
    "val_target_img_paths = target_img_paths[-val_samples:]\n",
    "\n",
    "# Instantiate data Sequences for each split\n",
    "train_gen = OxfordPets(\n",
    "    batch_size, img_size, train_input_img_paths, train_target_img_paths\n",
    ")\n",
    "val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Configure the model for training.\n",
    "# We use the \"sparse\" version of categorical_crossentropy\n",
    "# because our target data is integers.\n",
    "model.compile(optimizer=\"rmsprop\", loss=\"sparse_categorical_crossentropy\")\n",
    "\n",
    "callbacks = [\n",
    "    keras.callbacks.ModelCheckpoint(\"oxford_segmentation.h5\", save_best_only=True)\n",
    "]\n",
    "\n",
    "# Train the model, doing validation at the end of each epoch.\n",
    "epochs = 15\n",
    "model.fit(train_gen, epochs=epochs, validation_data=val_gen, callbacks=callbacks)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Visualize predictions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Generate predictions for all images in the validation set\n",
    "\n",
    "val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)\n",
    "val_preds = model.predict(val_gen)\n",
    "\n",
    "\n",
    "def display_mask(i):\n",
    "    \"\"\"Show the mask predicted by the model for entry `i` of `val_preds`.\"\"\"\n",
    "    # Highest-scoring class per pixel, kept as a trailing single channel.\n",
    "    predicted_classes = np.argmax(val_preds[i], axis=-1)[..., np.newaxis]\n",
    "    mask_img = keras.preprocessing.image.array_to_img(predicted_classes)\n",
    "    display(PIL.ImageOps.autocontrast(mask_img))\n",
    "\n",
    "\n",
    "# Display results for validation image #10\n",
    "i = 10\n",
    "\n",
    "# Display input image\n",
    "display(Image(filename=val_input_img_paths[i]))\n",
    "\n",
    "# Display ground-truth target mask\n",
    "img = PIL.ImageOps.autocontrast(load_img(val_target_img_paths[i]))\n",
    "display(img)\n",
    "\n",
    "# Display mask predicted by our model\n",
    "display_mask(i)  # Note that the model only sees inputs at 160x160.\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "oxford_pets_image_segmentation",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}